import numpy as np
import argparse
import random
import torch


def init_seeds(seed=42):
    """Seed every global RNG used by the training code with the same value.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch's
    default generator. According to the PyTorch docs, ``torch.manual_seed``
    covers all devices, including CUDA.

    Note: this does not enable deterministic cuDNN/CUDA kernels. Full
    bit-for-bit reproducibility on GPU may need additional settings.

    Args:
        seed: Integer seed applied to all three RNGs. Defaults to 42, which
            matches the previously hard-coded value.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def get_args(argv=None):
    """Parse and validate command-line options for federated CIFAR training.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), ``sys.argv[1:]`` is used, which matches the original
            behaviour. Passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace holding the parsed options plus a derived
        ``num_classes`` attribute (10 for cifar10, 100 for cifar100).

    Raises:
        SystemExit: via ``parser.error`` when an option is out of range. Such a
            value would otherwise only fail later, during partitioning or
            training.
    """
    parser = argparse.ArgumentParser(description="Training with CIFAR dataset")

    # Keep the class count next to the dataset choices so they cannot drift apart.
    classes_per_dataset = {"cifar10": 10, "cifar100": 100}

    parser.add_argument(
        "--dataset",
        type=str,
        choices=list(classes_per_dataset),
        default="cifar10",
        help="Choose between cifar10 and cifar100",
    )

    parser.add_argument("--num_clients", type=int, default=10, help="Number of clients")

    parser.add_argument(
        "--alpha",
        type=float,
        default=0.5,
        help="Dirichlet distribution alpha for non-IID partitioning",
    )

    parser.add_argument(
        "--batch_size", type=int, default=64, help="Batch size for each client"
    )

    parser.add_argument(
        "--frac",
        type=float,
        default=0.1,
        help="Fraction of clients selected in each round",
    )

    parser.add_argument(
        "--comm_rounds",
        type=int,
        default=50,
        help="Number of communication rounds",
    )

    parser.add_argument(
        "--local_epochs",
        type=int,
        default=5,
        help="Number of training epochs per client per round",
    )

    args = parser.parse_args(argv)

    # Reject values that cannot produce a meaningful run, with a clear CLI error.
    if args.num_clients < 1:
        parser.error("--num_clients must be a positive integer")
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")
    if not args.alpha > 0:
        # A Dirichlet concentration parameter must be strictly positive.
        parser.error("--alpha must be > 0")
    if not 0 < args.frac <= 1:
        parser.error("--frac must be in the interval (0, 1]")

    args.num_classes = classes_per_dataset[args.dataset]

    return args
